from contextlib import contextmanager
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy

from core.components.base_model import BaseModel
from core.modules.losses import L2Loss, PenaltyLoss
from core.modules.subnetworks import Encoder, Decoder, Predictor
from core.utils.general_utils import AttrDict, ParamDict, batch_apply, remove_spatial
from core.modules.layers import LayerBuilderParams
from core.modules.variational_inference import NoClampMultivariateGaussian
from core.utils.vis_utils import make_image_strip


class DeepMDPModel(BaseModel):
    """Representation learning through latent prediction, following DeepMDP (Gelada et al., ICML 2019).

    Two consecutive frames are embedded by a shared encoder. The embedding of the first frame,
    concatenated with the first action, is fed to a transition predictor (whose output is wrapped
    in a Gaussian over the next latent) and to one reward predictor per task. A decoder
    additionally produces an image from a sample of the predicted latent distribution.
    """

    def __init__(self, params, logger):
        """Merge `params` into the default hyperparameters and build all sub-networks.

        :param params: hyperparameter overrides written on top of `_default_hparams()`
        :param logger: logger object handed to the base class constructor
        """
        super().__init__(logger)
        self._hp = self._default_hparams()
        self._hp.overwrite(params)
        self._hp.builder = LayerBuilderParams(self._hp.use_convs, self._hp.normalization)

        # `task_names` entries are used as-is when they are strings; anything else is
        # called and the `.name` attribute of the result is stored instead.
        # NOTE(review): `task_names` has no default in this class -- presumably it is supplied
        # through `params` or the parent defaults; confirm against BaseModel.
        resolved_names = []
        for task in self._hp.task_names:
            resolved_names.append(task if isinstance(task, str) else task().name)
        self._task_names = resolved_names

        self.build_network()

    def _default_hparams(self):
        """Return the parent class defaults overwritten with the defaults of this model."""
        # Architecture switches
        default_dict = ParamDict({
            'use_skips': True,
            'skips_stride': 1,
            'add_weighted_pixel_copy': False,  # True --> decoder gets an extra pixel copying stream
            'pixel_shift_decoder': False,
            'use_convs': True,
            'normalization': 'batch',
            'discount': 0.9,            # discount factor (not read anywhere inside this class)
        })

        # Network size
        network_size = {
            'action_dim': -1,           # size of the action vector fed to the predictors
            'img_sz': 64,               # input image resolution
            'input_nc': 3,              # channel count of the input images
            'ngf': 8,                   # channels in the encoder's input layer --> doubled every layer
            'nz_enc': 128,              # size of the latent representation
            'nz_mid': 128,              # width of intermediate fully connected layers
            'n_processing_layers': 2,   # hidden layer count in non-conv nets
        }
        default_dict.update(network_size)

        # Loss weights
        loss_weights = {
            'latent_pred_weight': 1.,         # scales the latent prediction loss
            'reward_pred_weight': 1.,         # scales the reward prediction loss
            'rec_weight': 0.,                 # scales the observation reconstruction loss
        }
        default_dict.update(loss_weights)

        merged = super()._default_hparams()
        merged.overwrite(default_dict)
        return merged

    def build_network(self):
        """Instantiate encoder, decoder, the latent transition predictor and per-task reward predictors."""
        self.encoder = Encoder(self._hp)
        self.decoder = Decoder(self._hp)

        # all predictors consume the latent concatenated with the action
        pred_input_size = self._hp.nz_enc + self._hp.action_dim

        # the transition head emits twice the latent size; forward() wraps that output
        # in a NoClampMultivariateGaussian
        self.pred_mdl = Predictor(
            self._updated_mlp_params(),
            input_size=pred_input_size,
            output_size=2 * self._hp.nz_enc,
        )

        # one scalar reward head per task, each built from its own copy of the MLP params
        reward_heads = {}
        for name in self._task_names:
            reward_heads[name] = Predictor(self._updated_mlp_params(), input_size=pred_input_size, output_size=1)
        self.reward_mdls = nn.ModuleDict(reward_heads)

    def forward(self, inputs):
        """Encode both frames, predict the next latent and the rewards, and decode an image.

        :param inputs: object providing `images` (exactly two entries along dim 1) and `actions`
        :return: AttrDict with `z_prime`, `z_prime_hat`, `reward_pred` and `reconstruction`
        """
        assert inputs.images.shape[1] == 2      # training the predictive model needs two consecutive states

        # embed the first frame, then the second one
        z_first, skips = self.encoder(inputs.images[:, 0])
        z_second, _ = self.encoder(inputs.images[:, 1])

        # shared predictor input: latent of frame 0 joined with action 0
        first_action = inputs.actions[:, 0]
        pred_input = torch.cat((remove_spatial(z_first), first_action), dim=-1)

        output = AttrDict()

        # target latent and predicted next-latent distribution
        output.z_prime = remove_spatial(z_second)
        output.z_prime_hat = NoClampMultivariateGaussian(self.pred_mdl(pred_input))

        # one reward prediction per task, all computed from the same predictor input
        reward_preds = {}
        for name in self._task_names:
            reward_preds[name] = self.reward_mdls[name](pred_input)
        output.reward_pred = AttrDict(reward_preds)

        # decode from a sample of the predicted latent (two trailing singleton dims appended),
        # reusing the skip activations returned for the first frame
        latent_sample = output.z_prime_hat.rsample()[..., None, None]
        output.reconstruction = self.decoder(latent_sample, skips=skips).images

        return output

    def loss(self, model_output, inputs):
        """Assemble transition, per-task reward and reconstruction losses plus their total.

        :param model_output: AttrDict produced by `forward`
        :param inputs: object providing `images`, `rewards` and `task_id`
        :return: AttrDict of the individual losses and the entry `total`
        """
        losses = AttrDict()

        # transition model: penalty on the Gaussian negative log-likelihood of the target latent
        transition_nll = self._compute_trans_loss(model_output.z_prime_hat, model_output.z_prime)
        losses.transition_loss = PenaltyLoss(self._hp.latent_pred_weight)(transition_nll)

        # reward models: each head is only scored on samples whose task_id equals its index
        reward_losses = {}
        for task_idx, name in enumerate(self._task_names):
            task_mask = inputs.task_id == task_idx
            reward_losses['reward_loss_' + name] = L2Loss(self._hp.reward_pred_weight)(
                model_output.reward_pred[name][:, 0][task_mask],
                inputs.rewards[:, 0][task_mask],
            )
        losses.update(AttrDict(reward_losses))

        # reconstruction is compared against the second frame
        losses.rec_loss = L2Loss(self._hp.rec_weight)(model_output.reconstruction, inputs.images[:, 1])

        losses.total = self._compute_total_loss(losses)
        return losses

    def log_outputs(self, model_output, inputs, losses, step, log_images, phase):
        """Log the losses and, when `log_images` is set, a strip of targets next to reconstructions."""
        super()._log_losses(losses, step, log_images, phase)
        if not log_images:
            return

        # only the first three channels of target and reconstruction are visualized
        target_frames = inputs.images[:, 1, :3]
        decoded_frames = model_output.reconstruction[:, :3]
        img_strip = make_image_strip([target_frames, decoded_frames])
        self._logger.log_images(img_strip[None], "reconstruction", step, phase)

    def forward_encoder(self, inputs):
        """Run only the encoder on `inputs` and return its raw output."""
        return self.encoder(inputs)

    def _compute_trans_loss(self, z_prime_hat, trans_targets):
        """Mean of 0.5 * ((mu - target) / sigma)^2 + log(sigma) over all elements.

        The targets are detached, so no gradient reaches them through this loss.
        """
        target = trans_targets.detach()
        normalized_err = (z_prime_hat.mu - target) / z_prime_hat.sigma
        per_element_nll = normalized_err.pow(2) * 0.5 + z_prime_hat.sigma.log()
        return per_element_nll.mean()

    def _updated_mlp_params(self):
        """Return a deep copy of the hyperparameters with convolutions switched off."""
        mlp_params = copy.deepcopy(self._hp)
        mlp_builder = LayerBuilderParams(use_convs=False, normalization=self._hp.normalization)
        overrides = AttrDict(use_convs=False, builder=mlp_builder)
        return mlp_params.overwrite(overrides)

    @property
    def resolution(self):
        """Input image resolution, read from the `img_sz` hyperparameter."""
        return self._hp.img_sz

    @contextmanager
    def val_mode(self):
        """Validation context manager that performs no setup or teardown."""
        yield

